{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RMSProp算法\n",
    "\n",
    "我们在[“AdaGrad算法”](adagrad.ipynb)一节中提到，因为调整学习率时分母上的变量$\\boldsymbol{s}_t$一直在累加按元素平方的小批量随机梯度，所以目标函数自变量每个元素的学习率在迭代过程中一直在降低（或不变）。因此，当学习率在迭代早期降得较快且当前解依然不佳时，AdaGrad算法在迭代后期由于学习率过小，可能较难找到一个有用的解。为了解决这一问题，RMSProp算法对AdaGrad算法做了一点小小的修改。该算法源自Coursera上的一门课，即“机器学习的神经网络” [1]。\n",
    "\n",
    "## 算法\n",
    "\n",
    "我们在[“动量法”](momentum.ipynb)一节里介绍过指数加权移动平均。不同于AdaGrad算法里状态变量$\\boldsymbol{s}_t$是截至时间步$t$所有小批量随机梯度$\\boldsymbol{g}_t$按元素平方和，RMSProp算法将这些梯度按元素平方做指数加权移动平均。具体来说，给定超参数$0 \\leq \\gamma < 1$，RMSProp算法在时间步$t>0$计算\n",
    "\n",
    "$$\\boldsymbol{s}_t \\leftarrow \\gamma \\boldsymbol{s}_{t-1} + (1 - \\gamma) \\boldsymbol{g}_t \\odot \\boldsymbol{g}_t. $$\n",
    "\n",
    "和AdaGrad算法一样，RMSProp算法将目标函数自变量中每个元素的学习率通过按元素运算重新调整，然后更新自变量\n",
    "\n",
    "$$\\boldsymbol{x}_t \\leftarrow \\boldsymbol{x}_{t-1} - \\frac{\\eta}{\\sqrt{\\boldsymbol{s}_t + \\epsilon}} \\odot \\boldsymbol{g}_t, $$\n",
    "\n",
    "其中$\\eta$是学习率，$\\epsilon$是为了维持数值稳定性而添加的常数，如$10^{-6}$。因为RMSProp算法的状态变量$\\boldsymbol{s}_t$是对平方项$\\boldsymbol{g}_t \\odot \\boldsymbol{g}_t$的指数加权移动平均，所以可以看作最近$1/(1-\\gamma)$个时间步的小批量随机梯度平方项的加权平均。如此一来，自变量每个元素的学习率在迭代过程中就不再一直降低（或不变）。\n",
    "\n",
    "照例，让我们先观察RMSProp算法对目标函数$f(\\boldsymbol{x})=0.1x_1^2+2x_2^2$中自变量的迭代轨迹。回忆在[“AdaGrad算法”](adagrad.ipynb)一节使用的学习率为0.4的AdaGrad算法，自变量在迭代后期的移动幅度较小。但在同样的学习率下，RMSProp算法可以更快逼近最优解。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "3"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, x1 -0.010599, x2 0.000000\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
       "<svg height=\"184.455469pt\" version=\"1.1\" viewBox=\"0 0 248.620313 184.455469\" width=\"248.620313pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.455469 \n",
       "L 248.620313 184.455469 \n",
       "L 248.620313 0 \n",
       "L 0 0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 42.620312 146.899219 \n",
       "L 237.920313 146.899219 \n",
       "L 237.920313 10.999219 \n",
       "L 42.620312 10.999219 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"mb99b4acc94\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"88.39375\" xlink:href=\"#mb99b4acc94\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- −4 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.59375 35.5 \n",
       "L 73.1875 35.5 \n",
       "L 73.1875 27.203125 \n",
       "L 10.59375 27.203125 \n",
       "z\n",
       "\" id=\"DejaVuSans-2212\"/>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-34\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(81.022656 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-2212\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-34\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"149.425\" xlink:href=\"#mb99b4acc94\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- −2 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-32\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(142.053906 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-2212\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-32\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"210.45625\" xlink:href=\"#mb99b4acc94\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-30\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(207.275 161.497656)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_4\">\n",
       "     <!-- x1 -->\n",
       "     <defs>\n",
       "      <path d=\"M 54.890625 54.6875 \n",
       "L 35.109375 28.078125 \n",
       "L 55.90625 0 \n",
       "L 45.3125 0 \n",
       "L 29.390625 21.484375 \n",
       "L 13.484375 0 \n",
       "L 2.875 0 \n",
       "L 24.125 28.609375 \n",
       "L 4.6875 54.6875 \n",
       "L 15.28125 54.6875 \n",
       "L 29.78125 35.203125 \n",
       "L 44.28125 54.6875 \n",
       "z\n",
       "\" id=\"DejaVuSans-78\"/>\n",
       "      <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-31\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(134.129687 175.175781)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-78\"/>\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-31\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m5524e19152\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m5524e19152\" y=\"146.899219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- −3 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-33\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 150.698437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-2212\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-33\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m5524e19152\" y=\"112.924219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_6\">\n",
       "      <!-- −2 -->\n",
       "      <g transform=\"translate(20.878125 116.723437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-2212\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-32\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m5524e19152\" y=\"78.949219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- −1 -->\n",
       "      <g transform=\"translate(20.878125 82.748437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-2212\"/>\n",
       "       <use x=\"83.789062\" xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m5524e19152\" y=\"44.974219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0 -->\n",
       "      <g transform=\"translate(29.257812 48.773437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"42.620312\" xlink:href=\"#m5524e19152\" y=\"10.999219\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 1 -->\n",
       "      <g transform=\"translate(29.257812 14.798437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- x2 -->\n",
       "     <g transform=\"translate(14.798437 85.089844)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-78\"/>\n",
       "      <use x=\"59.179688\" xlink:href=\"#DejaVuSans-32\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_1\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 86.009224 \n",
       "L 234.86875 86.124739 \n",
       "L 231.817188 86.226664 \n",
       "L 228.765625 86.314999 \n",
       "L 225.714063 86.389744 \n",
       "L 222.6625 86.450899 \n",
       "L 219.610938 86.498464 \n",
       "L 216.559375 86.532439 \n",
       "L 213.507813 86.552824 \n",
       "L 210.45625 86.559619 \n",
       "L 207.404688 86.552824 \n",
       "L 204.353125 86.532439 \n",
       "L 201.301563 86.498464 \n",
       "L 198.25 86.450899 \n",
       "L 195.198438 86.389744 \n",
       "L 192.146875 86.314999 \n",
       "L 189.095313 86.226664 \n",
       "L 186.04375 86.124739 \n",
       "L 182.992188 86.009224 \n",
       "L 179.940625 85.880119 \n",
       "L 177.034375 85.744219 \n",
       "L 176.889063 85.736833 \n",
       "L 173.8375 85.566958 \n",
       "L 170.785938 85.382311 \n",
       "L 167.734375 85.182893 \n",
       "L 164.682813 84.968702 \n",
       "L 161.63125 84.73974 \n",
       "L 158.579688 84.496007 \n",
       "L 155.528125 84.237501 \n",
       "L 152.476563 83.964224 \n",
       "L 149.425 83.676175 \n",
       "L 146.373438 83.373355 \n",
       "L 143.321875 83.055762 \n",
       "L 140.270313 82.723398 \n",
       "L 137.21875 82.376262 \n",
       "L 136.969643 82.346719 \n",
       "L 134.167188 81.982701 \n",
       "L 131.115625 81.570147 \n",
       "L 128.064063 81.141415 \n",
       "L 125.0125 80.696504 \n",
       "L 121.960938 80.235415 \n",
       "L 118.909375 79.758147 \n",
       "L 115.857813 79.264701 \n",
       "L 113.96875 78.949219 \n",
       "L 112.80625 78.73464 \n",
       "L 109.754688 78.153488 \n",
       "L 106.703125 77.554456 \n",
       "L 103.651563 76.937541 \n",
       "L 100.6 76.302745 \n",
       "L 97.548438 75.650067 \n",
       "L 97.100875 75.551719 \n",
       "L 94.496875 74.912189 \n",
       "L 91.445313 74.142756 \n",
       "L 88.39375 73.353336 \n",
       "L 85.342188 72.543932 \n",
       "L 83.908321 72.154219 \n",
       "L 82.290625 71.655919 \n",
       "L 79.239063 70.693294 \n",
       "L 76.1875 69.708019 \n",
       "L 73.307374 68.756719 \n",
       "L 73.135938 68.691382 \n",
       "L 70.084375 67.502257 \n",
       "L 67.032813 66.286998 \n",
       "L 64.752171 65.359219 \n",
       "L 63.98125 64.988582 \n",
       "L 60.929688 63.490594 \n",
       "L 57.878125 61.961719 \n",
       "L 57.878125 61.961719 \n",
       "L 54.826563 60.055344 \n",
       "L 52.486044 58.564219 \n",
       "L 51.775 57.98179 \n",
       "L 48.723438 55.433665 \n",
       "L 48.409725 55.166719 \n",
       "L 45.671875 51.905119 \n",
       "L 45.559891 51.769219 \n",
       "L 43.880132 48.371719 \n",
       "L 43.320212 44.974219 \n",
       "L 43.880132 41.576719 \n",
       "L 45.559891 38.179219 \n",
       "L 45.671875 38.043319 \n",
       "L 48.409725 34.781719 \n",
       "L 48.723438 34.514772 \n",
       "L 51.775 31.966647 \n",
       "L 52.486044 31.384219 \n",
       "L 54.826563 29.893094 \n",
       "L 57.878125 27.986719 \n",
       "L 57.878125 27.986719 \n",
       "L 60.929688 26.457844 \n",
       "L 63.98125 24.959855 \n",
       "L 64.752171 24.589219 \n",
       "L 67.032813 23.66144 \n",
       "L 70.084375 22.44618 \n",
       "L 73.135938 21.257055 \n",
       "L 73.307374 21.191719 \n",
       "L 76.1875 20.240419 \n",
       "L 79.239063 19.255144 \n",
       "L 82.290625 18.292519 \n",
       "L 83.908321 17.794219 \n",
       "L 85.342188 17.404506 \n",
       "L 88.39375 16.595101 \n",
       "L 91.445313 15.805682 \n",
       "L 94.496875 15.036248 \n",
       "L 97.100875 14.396719 \n",
       "L 97.548438 14.29837 \n",
       "L 100.6 13.645692 \n",
       "L 103.651563 13.010896 \n",
       "L 106.703125 12.393982 \n",
       "L 109.754688 11.794949 \n",
       "L 112.80625 11.213798 \n",
       "L 113.96875 10.999219 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_2\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 103.406365 \n",
       "L 234.86875 103.488876 \n",
       "L 231.817188 103.561679 \n",
       "L 228.765625 103.624776 \n",
       "L 225.714063 103.678165 \n",
       "L 222.6625 103.721847 \n",
       "L 219.610938 103.755822 \n",
       "L 216.559375 103.78009 \n",
       "L 213.507813 103.794651 \n",
       "L 210.45625 103.799504 \n",
       "L 207.404688 103.794651 \n",
       "L 204.353125 103.78009 \n",
       "L 201.301563 103.755822 \n",
       "L 198.25 103.721847 \n",
       "L 195.198438 103.678165 \n",
       "L 192.146875 103.624776 \n",
       "L 189.095313 103.561679 \n",
       "L 186.04375 103.488876 \n",
       "L 182.992188 103.406365 \n",
       "L 179.940625 103.314147 \n",
       "L 176.889063 103.212222 \n",
       "L 173.8375 103.10059 \n",
       "L 170.785938 102.979251 \n",
       "L 167.734375 102.848204 \n",
       "L 165.208944 102.731719 \n",
       "L 164.682813 102.70598 \n",
       "L 161.63125 102.546401 \n",
       "L 158.579688 102.376526 \n",
       "L 155.528125 102.196355 \n",
       "L 152.476563 102.005889 \n",
       "L 149.425 101.805128 \n",
       "L 146.373438 101.594071 \n",
       "L 143.321875 101.372719 \n",
       "L 140.270313 101.141071 \n",
       "L 137.21875 100.899128 \n",
       "L 134.167188 100.646889 \n",
       "L 131.115625 100.384355 \n",
       "L 128.064063 100.111526 \n",
       "L 125.0125 99.828401 \n",
       "L 121.960938 99.53498 \n",
       "L 119.943803 99.334219 \n",
       "L 118.909375 99.224622 \n",
       "L 115.857813 98.890352 \n",
       "L 112.80625 98.545122 \n",
       "L 109.754688 98.188932 \n",
       "L 106.703125 97.821783 \n",
       "L 103.651563 97.443674 \n",
       "L 100.6 97.054606 \n",
       "L 97.548438 96.654578 \n",
       "L 94.496875 96.24359 \n",
       "L 92.277557 95.936719 \n",
       "L 91.445313 95.813706 \n",
       "L 88.39375 95.350943 \n",
       "L 85.342188 94.876464 \n",
       "L 82.290625 94.39027 \n",
       "L 79.239063 93.892361 \n",
       "L 76.1875 93.382736 \n",
       "L 73.135938 92.861395 \n",
       "L 71.291587 92.539219 \n",
       "L 70.084375 92.312719 \n",
       "L 67.032813 91.727594 \n",
       "L 63.98125 91.129885 \n",
       "L 60.929688 90.519594 \n",
       "L 57.878125 89.896719 \n",
       "L 54.826563 89.26126 \n",
       "L 54.263653 89.141719 \n",
       "L 51.775 88.570939 \n",
       "L 48.723438 87.857464 \n",
       "L 45.671875 87.130399 \n",
       "L 42.620313 86.389744 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_3\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 116.712826 \n",
       "L 234.86875 116.779986 \n",
       "L 231.817188 116.839245 \n",
       "L 228.765625 116.890602 \n",
       "L 225.714063 116.934059 \n",
       "L 222.6625 116.969614 \n",
       "L 219.610938 116.997268 \n",
       "L 216.559375 117.017021 \n",
       "L 213.507813 117.028873 \n",
       "L 210.45625 117.032823 \n",
       "L 207.404688 117.028873 \n",
       "L 204.353125 117.017021 \n",
       "L 201.301563 116.997268 \n",
       "L 198.25 116.969614 \n",
       "L 195.198438 116.934059 \n",
       "L 192.146875 116.890602 \n",
       "L 189.095313 116.839245 \n",
       "L 186.04375 116.779986 \n",
       "L 182.992188 116.712826 \n",
       "L 179.940625 116.637765 \n",
       "L 176.889063 116.554803 \n",
       "L 173.8375 116.46394 \n",
       "L 170.785938 116.365175 \n",
       "L 169.542708 116.321719 \n",
       "L 167.734375 116.255426 \n",
       "L 164.682813 116.135271 \n",
       "L 161.63125 116.006829 \n",
       "L 158.579688 115.8701 \n",
       "L 155.528125 115.725085 \n",
       "L 152.476563 115.571783 \n",
       "L 149.425 115.410194 \n",
       "L 146.373438 115.240319 \n",
       "L 143.321875 115.062158 \n",
       "L 140.270313 114.87571 \n",
       "L 137.21875 114.680975 \n",
       "L 134.167188 114.477954 \n",
       "L 131.115625 114.266646 \n",
       "L 128.064063 114.047051 \n",
       "L 125.0125 113.81917 \n",
       "L 121.960938 113.583002 \n",
       "L 118.909375 113.338548 \n",
       "L 115.857813 113.085807 \n",
       "L 113.96875 112.924219 \n",
       "L 112.80625 112.81968 \n",
       "L 109.754688 112.536555 \n",
       "L 106.703125 112.244719 \n",
       "L 103.651563 111.944171 \n",
       "L 100.6 111.634911 \n",
       "L 97.548438 111.31694 \n",
       "L 94.496875 110.990257 \n",
       "L 91.445313 110.654863 \n",
       "L 88.39375 110.310757 \n",
       "L 85.342188 109.95794 \n",
       "L 82.290625 109.596411 \n",
       "L 81.716213 109.526719 \n",
       "L 79.239063 109.209925 \n",
       "L 76.1875 108.810489 \n",
       "L 73.135938 108.401871 \n",
       "L 70.084375 107.98407 \n",
       "L 67.032813 107.557087 \n",
       "L 63.98125 107.120921 \n",
       "L 60.929688 106.675573 \n",
       "L 57.878125 106.221043 \n",
       "L 57.273855 106.129219 \n",
       "L 54.826563 105.736079 \n",
       "L 51.775 105.236162 \n",
       "L 48.723438 104.726537 \n",
       "L 45.671875 104.207204 \n",
       "L 42.620313 103.678165 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_4\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 127.897487 \n",
       "L 234.86875 127.956423 \n",
       "L 231.817188 128.008425 \n",
       "L 228.765625 128.053494 \n",
       "L 225.714063 128.091629 \n",
       "L 222.6625 128.122831 \n",
       "L 219.610938 128.147099 \n",
       "L 216.559375 128.164433 \n",
       "L 213.507813 128.174834 \n",
       "L 210.45625 128.1783 \n",
       "L 207.404688 128.174834 \n",
       "L 204.353125 128.164433 \n",
       "L 201.301563 128.147099 \n",
       "L 198.25 128.122831 \n",
       "L 195.198438 128.091629 \n",
       "L 192.146875 128.053494 \n",
       "L 189.095313 128.008425 \n",
       "L 186.04375 127.956423 \n",
       "L 182.992188 127.897487 \n",
       "L 179.940625 127.831617 \n",
       "L 176.889063 127.758813 \n",
       "L 173.8375 127.679076 \n",
       "L 170.785938 127.592405 \n",
       "L 167.734375 127.4988 \n",
       "L 164.682813 127.398262 \n",
       "L 161.63125 127.29079 \n",
       "L 158.579688 127.176385 \n",
       "L 155.528125 127.055045 \n",
       "L 152.476563 126.926772 \n",
       "L 149.425 126.791566 \n",
       "L 146.373438 126.649425 \n",
       "L 143.605741 126.514219 \n",
       "L 143.321875 126.499761 \n",
       "L 140.270313 126.337115 \n",
       "L 137.21875 126.16724 \n",
       "L 134.167188 125.990136 \n",
       "L 131.115625 125.805804 \n",
       "L 128.064063 125.614243 \n",
       "L 125.0125 125.415453 \n",
       "L 121.960938 125.209434 \n",
       "L 118.909375 124.996187 \n",
       "L 115.857813 124.775711 \n",
       "L 112.80625 124.548006 \n",
       "L 109.754688 124.313072 \n",
       "L 106.703125 124.07091 \n",
       "L 103.651563 123.821519 \n",
       "L 100.6 123.5649 \n",
       "L 97.548438 123.301051 \n",
       "L 95.473375 123.116719 \n",
       "L 94.496875 123.026119 \n",
       "L 91.445313 122.735444 \n",
       "L 88.39375 122.437219 \n",
       "L 85.342188 122.131444 \n",
       "L 82.290625 121.818119 \n",
       "L 79.239063 121.497244 \n",
       "L 76.1875 121.168819 \n",
       "L 73.135938 120.832844 \n",
       "L 70.084375 120.489319 \n",
       "L 67.032813 120.138244 \n",
       "L 63.98125 119.779619 \n",
       "L 63.477899 119.719219 \n",
       "L 60.929688 119.399222 \n",
       "L 57.878125 119.008114 \n",
       "L 54.826563 118.609105 \n",
       "L 51.775 118.202195 \n",
       "L 48.723438 117.787384 \n",
       "L 45.671875 117.364672 \n",
       "L 42.620313 116.934059 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_5\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 137.753766 \n",
       "L 234.86875 137.806273 \n",
       "L 231.817188 137.852603 \n",
       "L 228.765625 137.892755 \n",
       "L 225.714063 137.92673 \n",
       "L 222.6625 137.954528 \n",
       "L 219.610938 137.976148 \n",
       "L 216.559375 137.991591 \n",
       "L 213.507813 138.000857 \n",
       "L 210.45625 138.003946 \n",
       "L 207.404688 138.000857 \n",
       "L 204.353125 137.991591 \n",
       "L 201.301563 137.976148 \n",
       "L 198.25 137.954528 \n",
       "L 195.198438 137.92673 \n",
       "L 192.146875 137.892755 \n",
       "L 189.095313 137.852603 \n",
       "L 186.04375 137.806273 \n",
       "L 182.992188 137.753766 \n",
       "L 179.940625 137.695082 \n",
       "L 176.889063 137.630221 \n",
       "L 173.8375 137.559182 \n",
       "L 170.785938 137.481966 \n",
       "L 167.734375 137.398573 \n",
       "L 164.682813 137.309003 \n",
       "L 161.63125 137.213255 \n",
       "L 158.579688 137.11133 \n",
       "L 155.528125 137.003228 \n",
       "L 152.476563 136.888948 \n",
       "L 149.425 136.768491 \n",
       "L 147.936433 136.706719 \n",
       "L 146.373438 136.63941 \n",
       "L 143.321875 136.501587 \n",
       "L 140.270313 136.357353 \n",
       "L 137.21875 136.206709 \n",
       "L 134.167188 136.049655 \n",
       "L 131.115625 135.88619 \n",
       "L 128.064063 135.716315 \n",
       "L 125.0125 135.54003 \n",
       "L 121.960938 135.357334 \n",
       "L 118.909375 135.168228 \n",
       "L 115.857813 134.972712 \n",
       "L 112.80625 134.770785 \n",
       "L 109.754688 134.562448 \n",
       "L 106.703125 134.3477 \n",
       "L 103.651563 134.126542 \n",
       "L 100.6 133.898973 \n",
       "L 97.548438 133.664995 \n",
       "L 94.496875 133.424606 \n",
       "L 93.07017 133.309219 \n",
       "L 91.445313 133.172653 \n",
       "L 88.39375 132.909513 \n",
       "L 85.342188 132.639711 \n",
       "L 82.290625 132.363248 \n",
       "L 79.239063 132.080123 \n",
       "L 76.1875 131.790336 \n",
       "L 73.135938 131.493888 \n",
       "L 70.084375 131.190778 \n",
       "L 67.032813 130.881006 \n",
       "L 63.98125 130.564572 \n",
       "L 60.929688 130.241476 \n",
       "L 57.878125 129.911719 \n",
       "L 57.878125 129.911719 \n",
       "L 54.826563 129.561568 \n",
       "L 51.775 129.204484 \n",
       "L 48.723438 128.840466 \n",
       "L 45.671875 128.469515 \n",
       "L 42.620313 128.091629 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_6\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 210.45625 146.899219 \n",
       "L 207.404688 146.89634 \n",
       "L 204.353125 146.887702 \n",
       "L 201.301563 146.873306 \n",
       "L 198.25 146.853151 \n",
       "L 195.198438 146.827238 \n",
       "L 192.146875 146.795566 \n",
       "L 189.095313 146.758136 \n",
       "L 186.04375 146.714948 \n",
       "L 182.992188 146.666001 \n",
       "L 179.940625 146.611295 \n",
       "L 176.889063 146.550831 \n",
       "L 173.8375 146.484609 \n",
       "L 170.785938 146.412628 \n",
       "L 167.734375 146.334888 \n",
       "L 164.682813 146.25139 \n",
       "L 161.63125 146.162134 \n",
       "L 158.579688 146.067119 \n",
       "L 155.528125 145.966346 \n",
       "L 152.476563 145.859814 \n",
       "L 149.425 145.747524 \n",
       "L 146.373438 145.629475 \n",
       "L 143.321875 145.505668 \n",
       "L 140.270313 145.376102 \n",
       "L 137.21875 145.240778 \n",
       "L 134.167188 145.099695 \n",
       "L 131.115625 144.952854 \n",
       "L 128.064063 144.800255 \n",
       "L 125.0125 144.641897 \n",
       "L 121.960938 144.47778 \n",
       "L 118.909375 144.307905 \n",
       "L 115.857813 144.132272 \n",
       "L 112.80625 143.95088 \n",
       "L 109.754688 143.763729 \n",
       "L 106.703125 143.57082 \n",
       "L 105.641712 143.501719 \n",
       "L 103.651563 143.367607 \n",
       "L 100.6 143.156008 \n",
       "L 97.548438 142.938449 \n",
       "L 94.496875 142.714929 \n",
       "L 91.445313 142.485449 \n",
       "L 88.39375 142.250008 \n",
       "L 85.342188 142.008607 \n",
       "L 82.290625 141.761245 \n",
       "L 79.239063 141.507923 \n",
       "L 76.1875 141.24864 \n",
       "L 73.135938 140.983396 \n",
       "L 70.084375 140.712192 \n",
       "L 67.032813 140.435028 \n",
       "L 63.98125 140.151903 \n",
       "L 63.477899 140.104219 \n",
       "L 60.929688 139.854039 \n",
       "L 57.878125 139.548264 \n",
       "L 54.826563 139.236312 \n",
       "L 51.775 138.918182 \n",
       "L 48.723438 138.593876 \n",
       "L 45.671875 138.263391 \n",
       "L 42.620313 137.92673 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 237.920313 146.666001 \n",
       "L 234.86875 146.714948 \n",
       "L 231.817188 146.758136 \n",
       "L 228.765625 146.795566 \n",
       "L 225.714063 146.827238 \n",
       "L 222.6625 146.853151 \n",
       "L 219.610938 146.873306 \n",
       "L 216.559375 146.887702 \n",
       "L 213.507813 146.89634 \n",
       "L 210.45625 146.899219 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"LineCollection_7\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 43.320212 146.899219 \n",
       "L 42.620313 146.827238 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"line2d_9\">\n",
       "    <path clip-path=\"url(#p554c127ca9)\" d=\"M 57.878125 112.924219 \n",
       "L 96.477484 69.948869 \n",
       "L 120.357142 54.42355 \n",
       "L 138.043261 48.294977 \n",
       "L 152.011133 46.027885 \n",
       "L 163.368487 45.269751 \n",
       "L 172.726061 45.04561 \n",
       "L 180.468975 44.988536 \n",
       "L 186.866078 44.976471 \n",
       "L 192.121325 44.97447 \n",
       "L 196.40025 44.974235 \n",
       "L 199.844228 44.974219 \n",
       "L 202.578131 44.974219 \n",
       "L 204.714223 44.974219 \n",
       "L 206.353865 44.974219 \n",
       "L 207.588039 44.974219 \n",
       "L 208.497343 44.974219 \n",
       "L 209.15187 44.974219 \n",
       "L 209.611243 44.974219 \n",
       "L 209.924922 44.974219 \n",
       "L 210.132825 44.974219 \n",
       "\" style=\"fill:none;stroke:#ff7f0e;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "    <defs>\n",
       "     <path d=\"M 0 3 \n",
       "C 0.795609 3 1.55874 2.683901 2.12132 2.12132 \n",
       "C 2.683901 1.55874 3 0.795609 3 0 \n",
       "C 3 -0.795609 2.683901 -1.55874 2.12132 -2.12132 \n",
       "C 1.55874 -2.683901 0.795609 -3 0 -3 \n",
       "C -0.795609 -3 -1.55874 -2.683901 -2.12132 -2.12132 \n",
       "C -2.683901 -1.55874 -3 -0.795609 -3 0 \n",
       "C -3 0.795609 -2.683901 1.55874 -2.12132 2.12132 \n",
       "C -1.55874 2.683901 -0.795609 3 0 3 \n",
       "z\n",
       "\" id=\"mbe00751a98\" style=\"stroke:#ff7f0e;\"/>\n",
       "    </defs>\n",
       "    <g clip-path=\"url(#p554c127ca9)\">\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"57.878125\" xlink:href=\"#mbe00751a98\" y=\"112.924219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"96.477484\" xlink:href=\"#mbe00751a98\" y=\"69.948869\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"120.357142\" xlink:href=\"#mbe00751a98\" y=\"54.42355\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"138.043261\" xlink:href=\"#mbe00751a98\" y=\"48.294977\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"152.011133\" xlink:href=\"#mbe00751a98\" y=\"46.027885\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"163.368487\" xlink:href=\"#mbe00751a98\" y=\"45.269751\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"172.726061\" xlink:href=\"#mbe00751a98\" y=\"45.04561\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"180.468975\" xlink:href=\"#mbe00751a98\" y=\"44.988536\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"186.866078\" xlink:href=\"#mbe00751a98\" y=\"44.976471\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"192.121325\" xlink:href=\"#mbe00751a98\" y=\"44.97447\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"196.40025\" xlink:href=\"#mbe00751a98\" y=\"44.974235\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"199.844228\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"202.578131\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"204.714223\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"206.353865\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"207.588039\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"208.497343\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.15187\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.611243\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"209.924922\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "     <use style=\"fill:#ff7f0e;stroke:#ff7f0e;\" x=\"210.132825\" xlink:href=\"#mbe00751a98\" y=\"44.974219\"/>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 42.620312 146.899219 \n",
       "L 42.620312 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 237.920313 146.899219 \n",
       "L 237.920313 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 42.620313 146.899219 \n",
       "L 237.920313 146.899219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 42.620313 10.999219 \n",
       "L 237.920313 10.999219 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p554c127ca9\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"42.620312\" y=\"10.999219\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
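   ],
   "source": [
    "%matplotlib inline\n",
    "import d2lzh as d2l\n",
    "import math\n",
    "from mxnet import nd\n",
    "\n",
    "def rmsprop_2d(x1, x2, s1, s2):\n",
    "    # Gradients of f(x) = 0.1 * x1^2 + 2 * x2^2; eps keeps the division stable\n",
    "    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6\n",
    "    # Exponentially weighted moving average of the squared gradients\n",
    "    s1 = gamma * s1 + (1 - gamma) * g1 ** 2\n",
    "    s2 = gamma * s2 + (1 - gamma) * g2 ** 2\n",
    "    # Scale the learning rate per coordinate, then update\n",
    "    x1 -= eta / math.sqrt(s1 + eps) * g1\n",
    "    x2 -= eta / math.sqrt(s2 + eps) * g2\n",
    "    return x1, x2, s1, s2\n",
    "\n",
    "def f_2d(x1, x2):\n",
    "    return 0.1 * x1 ** 2 + 2 * x2 ** 2\n",
    "\n",
    "# eta (learning rate) and gamma are read as globals by rmsprop_2d\n",
    "eta, gamma = 0.4, 0.9\n",
    "d2l.show_trace_2d(f_2d, d2l.train_2d(rmsprop_2d))"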
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementation from Scratch\n",
    "\n",
    "Next, we implement RMSProp following the formulas in the algorithm above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "22"
    }
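   },
   "outputs": [],
   "source": [
    "features, labels = d2l.get_data_ch7()\n",
    "\n",
    "def init_rmsprop_states():\n",
    "    # One state per parameter, initialized to zero: s_w for the weights, s_b for the bias\n",
    "    s_w = nd.zeros((features.shape[1], 1))\n",
    "    s_b = nd.zeros(1)\n",
    "    return (s_w, s_b)\n",
    "\n",
    "def rmsprop(params, states, hyperparams):\n",
    "    gamma, eps = hyperparams['gamma'], 1e-6\n",
    "    for p, s in zip(params, states):\n",
    "        # Update the state in place: EWMA of the element-wise squared gradient\n",
    "        s[:] = gamma * s + (1 - gamma) * p.grad.square()\n",
    "        # Element-wise scaled step; eps is added inside the square root for stability\n",
    "        p[:] -= hyperparams['lr'] * p.grad / (s + eps).sqrt()"
   ]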
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将初始学习率设为0.01，并将超参数$\\gamma$设为0.9。此时，变量$\\boldsymbol{s}_t$可看作最近$1/(1-0.9) = 10$个时间步的平方项$\\boldsymbol{g}_t \\odot \\boldsymbol{g}_t$的加权平均。"
   ]
  },
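  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the 10-time-step reading concrete, the update of $\\boldsymbol{s}_t$ can be unrolled. This is a short worked derivation added for illustration; it assumes the state starts at $\\boldsymbol{s}_0 = \\boldsymbol{0}$, as in `init_rmsprop_states`:\n",
    "\n",
    "$$\\boldsymbol{s}_t = (1-\\gamma) \\sum_{i=0}^{t-1} \\gamma^{i} \\boldsymbol{g}_{t-i} \\odot \\boldsymbol{g}_{t-i}.$$\n",
    "\n",
    "With $\\gamma = 0.9$, the weights on the most recent squared gradients are $0.1, 0.09, 0.081, \\ldots$. Relative to the newest term, a squared gradient from 10 steps back is scaled by $0.9^{10} \\approx 0.35$ (close to $1/e \\approx 0.37$). For large $t$ the weights sum to roughly 1, and the 10 most recent terms account for $1 - 0.9^{10} \\approx 0.65$ of that total, so older gradients fade out instead of accumulating as in AdaGrad."
   ]
  },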
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "24"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.246587, 0.382151 sec per epoch\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
       "<svg height=\"184.15625pt\" version=\"1.1\" viewBox=\"0 0 249.78125 184.15625\" width=\"249.78125pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.15625 \n",
       "L 249.78125 184.15625 \n",
       "L 249.78125 -0 \n",
       "L 0 -0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 43.78125 146.6 \n",
       "L 239.08125 146.6 \n",
       "L 239.08125 10.7 \n",
       "L 43.78125 10.7 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"mc253bdbfd8\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"52.658523\" xlink:href=\"#mc253bdbfd8\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-30\"/>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-2e\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(44.70696 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"97.044886\" xlink:href=\"#mc253bdbfd8\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 0.5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-35\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(89.093324 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"141.43125\" xlink:href=\"#mc253bdbfd8\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 1.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-31\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(133.479688 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"185.817614\" xlink:href=\"#mc253bdbfd8\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 1.5 -->\n",
       "      <g transform=\"translate(177.866051 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"230.203977\" xlink:href=\"#mc253bdbfd8\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 2.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-32\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(222.252415 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <defs>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-65\"/>\n",
       "      <path d=\"M 18.109375 8.203125 \n",
       "L 18.109375 -20.796875 \n",
       "L 9.078125 -20.796875 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.390625 \n",
       "Q 20.953125 51.265625 25.265625 53.625 \n",
       "Q 29.59375 56 35.59375 56 \n",
       "Q 45.5625 56 51.78125 48.09375 \n",
       "Q 58.015625 40.1875 58.015625 27.296875 \n",
       "Q 58.015625 14.40625 51.78125 6.484375 \n",
       "Q 45.5625 -1.421875 35.59375 -1.421875 \n",
       "Q 29.59375 -1.421875 25.265625 0.953125 \n",
       "Q 20.953125 3.328125 18.109375 8.203125 \n",
       "z\n",
       "M 48.6875 27.296875 \n",
       "Q 48.6875 37.203125 44.609375 42.84375 \n",
       "Q 40.53125 48.484375 33.40625 48.484375 \n",
       "Q 26.265625 48.484375 22.1875 42.84375 \n",
       "Q 18.109375 37.203125 18.109375 27.296875 \n",
       "Q 18.109375 17.390625 22.1875 11.75 \n",
       "Q 26.265625 6.109375 33.40625 6.109375 \n",
       "Q 40.53125 6.109375 44.609375 11.75 \n",
       "Q 48.6875 17.390625 48.6875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-70\"/>\n",
       "      <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-6f\"/>\n",
       "      <path d=\"M 48.78125 52.59375 \n",
       "L 48.78125 44.1875 \n",
       "Q 44.96875 46.296875 41.140625 47.34375 \n",
       "Q 37.3125 48.390625 33.40625 48.390625 \n",
       "Q 24.65625 48.390625 19.8125 42.84375 \n",
       "Q 14.984375 37.3125 14.984375 27.296875 \n",
       "Q 14.984375 17.28125 19.8125 11.734375 \n",
       "Q 24.65625 6.203125 33.40625 6.203125 \n",
       "Q 37.3125 6.203125 41.140625 7.25 \n",
       "Q 44.96875 8.296875 48.78125 10.40625 \n",
       "L 48.78125 2.09375 \n",
       "Q 45.015625 0.34375 40.984375 -0.53125 \n",
       "Q 36.96875 -1.421875 32.421875 -1.421875 \n",
       "Q 20.0625 -1.421875 12.78125 6.34375 \n",
       "Q 5.515625 14.109375 5.515625 27.296875 \n",
       "Q 5.515625 40.671875 12.859375 48.328125 \n",
       "Q 20.21875 56 33.015625 56 \n",
       "Q 37.15625 56 41.109375 55.140625 \n",
       "Q 45.0625 54.296875 48.78125 52.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-63\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 75.984375 \n",
       "L 18.109375 75.984375 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-68\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(126.203125 174.876562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
       "      <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
       "      <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"me0ad9e5efd\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"43.78125\" xlink:href=\"#me0ad9e5efd\" y=\"114.597312\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.3 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-33\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 118.396531)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"43.78125\" xlink:href=\"#me0ad9e5efd\" y=\"69.616795\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.4 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-34\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 73.416014)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"43.78125\" xlink:href=\"#me0ad9e5efd\" y=\"24.636278\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.5 -->\n",
       "      <g transform=\"translate(20.878125 28.435497)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_10\">\n",
       "     <!-- loss -->\n",
       "     <defs>\n",
       "      <path d=\"M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-6c\"/>\n",
       "      <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-73\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-6c\"/>\n",
       "      <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
       "      <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
       "      <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_9\">\n",
       "    <path clip-path=\"url(#pc108231f23)\" d=\"M 52.658523 16.877273 \n",
       "L 58.576705 59.425688 \n",
       "L 64.494886 92.331635 \n",
       "L 70.413068 105.688641 \n",
       "L 76.33125 118.87957 \n",
       "L 82.249432 119.961293 \n",
       "L 88.167614 127.778424 \n",
       "L 94.085795 128.101932 \n",
       "L 100.003977 133.256568 \n",
       "L 105.922159 135.523247 \n",
       "L 111.840341 137.242643 \n",
       "L 117.758523 138.547194 \n",
       "L 123.676705 139.147681 \n",
       "L 129.594886 138.979003 \n",
       "L 135.513068 139.343947 \n",
       "L 141.43125 138.483693 \n",
       "L 147.349432 139.746245 \n",
       "L 153.267614 139.396141 \n",
       "L 159.185795 139.616784 \n",
       "L 165.103977 139.286714 \n",
       "L 171.022159 139.501888 \n",
       "L 176.940341 140.046697 \n",
       "L 182.858523 139.794102 \n",
       "L 188.776705 138.866051 \n",
       "L 194.694886 137.04572 \n",
       "L 200.613068 135.223666 \n",
       "L 206.53125 140.422727 \n",
       "L 212.449432 139.278751 \n",
       "L 218.367614 138.614361 \n",
       "L 224.285795 138.415132 \n",
       "L 230.203977 138.622893 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 43.78125 146.6 \n",
       "L 43.78125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 239.08125 146.6 \n",
       "L 239.08125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 43.78125 146.6 \n",
       "L 239.08125 146.6 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 43.78125 10.7 \n",
       "L 239.08125 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"pc108231f23\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"43.78125\" y=\"10.7\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.train_ch7(rmsprop, init_rmsprop_states(), {'lr': 0.01, 'gamma': 0.9},\n",
    "              features, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简洁实现\n",
    "\n",
    "通过名称为“rmsprop”的`Trainer`实例，我们便可使用Gluon提供的RMSProp算法来训练模型。注意，超参数$\\gamma$通过`gamma1`指定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "29"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.247679, 0.256022 sec per epoch\n"
     ]
    },
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"utf-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       "  \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Created with matplotlib (http://matplotlib.org/) -->\n",
       "<svg height=\"184.15625pt\" version=\"1.1\" viewBox=\"0 0 256.14375 184.15625\" width=\"256.14375pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       " <defs>\n",
       "  <style type=\"text/css\">\n",
       "*{stroke-linecap:butt;stroke-linejoin:round;}\n",
       "  </style>\n",
       " </defs>\n",
       " <g id=\"figure_1\">\n",
       "  <g id=\"patch_1\">\n",
       "   <path d=\"M 0 184.15625 \n",
       "L 256.14375 184.15625 \n",
       "L 256.14375 -0 \n",
       "L 0 -0 \n",
       "z\n",
       "\" style=\"fill:none;\"/>\n",
       "  </g>\n",
       "  <g id=\"axes_1\">\n",
       "   <g id=\"patch_2\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 245.44375 146.6 \n",
       "L 245.44375 10.7 \n",
       "L 50.14375 10.7 \n",
       "z\n",
       "\" style=\"fill:#ffffff;\"/>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_1\">\n",
       "    <g id=\"xtick_1\">\n",
       "     <g id=\"line2d_1\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L 0 3.5 \n",
       "\" id=\"m920d6e4fcb\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"59.021023\" xlink:href=\"#m920d6e4fcb\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_1\">\n",
       "      <!-- 0.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 31.78125 66.40625 \n",
       "Q 24.171875 66.40625 20.328125 58.90625 \n",
       "Q 16.5 51.421875 16.5 36.375 \n",
       "Q 16.5 21.390625 20.328125 13.890625 \n",
       "Q 24.171875 6.390625 31.78125 6.390625 \n",
       "Q 39.453125 6.390625 43.28125 13.890625 \n",
       "Q 47.125 21.390625 47.125 36.375 \n",
       "Q 47.125 51.421875 43.28125 58.90625 \n",
       "Q 39.453125 66.40625 31.78125 66.40625 \n",
       "z\n",
       "M 31.78125 74.21875 \n",
       "Q 44.046875 74.21875 50.515625 64.515625 \n",
       "Q 56.984375 54.828125 56.984375 36.375 \n",
       "Q 56.984375 17.96875 50.515625 8.265625 \n",
       "Q 44.046875 -1.421875 31.78125 -1.421875 \n",
       "Q 19.53125 -1.421875 13.0625 8.265625 \n",
       "Q 6.59375 17.96875 6.59375 36.375 \n",
       "Q 6.59375 54.828125 13.0625 64.515625 \n",
       "Q 19.53125 74.21875 31.78125 74.21875 \n",
       "z\n",
       "\" id=\"DejaVuSans-30\"/>\n",
       "       <path d=\"M 10.6875 12.40625 \n",
       "L 21 12.40625 \n",
       "L 21 0 \n",
       "L 10.6875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-2e\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(51.06946 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_2\">\n",
       "     <g id=\"line2d_2\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"103.407386\" xlink:href=\"#m920d6e4fcb\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_2\">\n",
       "      <!-- 0.5 -->\n",
       "      <defs>\n",
       "       <path d=\"M 10.796875 72.90625 \n",
       "L 49.515625 72.90625 \n",
       "L 49.515625 64.59375 \n",
       "L 19.828125 64.59375 \n",
       "L 19.828125 46.734375 \n",
       "Q 21.96875 47.46875 24.109375 47.828125 \n",
       "Q 26.265625 48.1875 28.421875 48.1875 \n",
       "Q 40.625 48.1875 47.75 41.5 \n",
       "Q 54.890625 34.8125 54.890625 23.390625 \n",
       "Q 54.890625 11.625 47.5625 5.09375 \n",
       "Q 40.234375 -1.421875 26.90625 -1.421875 \n",
       "Q 22.3125 -1.421875 17.546875 -0.640625 \n",
       "Q 12.796875 0.140625 7.71875 1.703125 \n",
       "L 7.71875 11.625 \n",
       "Q 12.109375 9.234375 16.796875 8.0625 \n",
       "Q 21.484375 6.890625 26.703125 6.890625 \n",
       "Q 35.15625 6.890625 40.078125 11.328125 \n",
       "Q 45.015625 15.765625 45.015625 23.390625 \n",
       "Q 45.015625 31 40.078125 35.4375 \n",
       "Q 35.15625 39.890625 26.703125 39.890625 \n",
       "Q 22.75 39.890625 18.8125 39.015625 \n",
       "Q 14.890625 38.140625 10.796875 36.28125 \n",
       "z\n",
       "\" id=\"DejaVuSans-35\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(95.455824 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_3\">\n",
       "     <g id=\"line2d_3\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"147.79375\" xlink:href=\"#m920d6e4fcb\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_3\">\n",
       "      <!-- 1.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 12.40625 8.296875 \n",
       "L 28.515625 8.296875 \n",
       "L 28.515625 63.921875 \n",
       "L 10.984375 60.40625 \n",
       "L 10.984375 69.390625 \n",
       "L 28.421875 72.90625 \n",
       "L 38.28125 72.90625 \n",
       "L 38.28125 8.296875 \n",
       "L 54.390625 8.296875 \n",
       "L 54.390625 0 \n",
       "L 12.40625 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-31\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(139.842187 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_4\">\n",
       "     <g id=\"line2d_4\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"192.180114\" xlink:href=\"#m920d6e4fcb\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_4\">\n",
       "      <!-- 1.5 -->\n",
       "      <g transform=\"translate(184.228551 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-31\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"xtick_5\">\n",
       "     <g id=\"line2d_5\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"236.566477\" xlink:href=\"#m920d6e4fcb\" y=\"146.6\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_5\">\n",
       "      <!-- 2.0 -->\n",
       "      <defs>\n",
       "       <path d=\"M 19.1875 8.296875 \n",
       "L 53.609375 8.296875 \n",
       "L 53.609375 0 \n",
       "L 7.328125 0 \n",
       "L 7.328125 8.296875 \n",
       "Q 12.9375 14.109375 22.625 23.890625 \n",
       "Q 32.328125 33.6875 34.8125 36.53125 \n",
       "Q 39.546875 41.84375 41.421875 45.53125 \n",
       "Q 43.3125 49.21875 43.3125 52.78125 \n",
       "Q 43.3125 58.59375 39.234375 62.25 \n",
       "Q 35.15625 65.921875 28.609375 65.921875 \n",
       "Q 23.96875 65.921875 18.8125 64.3125 \n",
       "Q 13.671875 62.703125 7.8125 59.421875 \n",
       "L 7.8125 69.390625 \n",
       "Q 13.765625 71.78125 18.9375 73 \n",
       "Q 24.125 74.21875 28.421875 74.21875 \n",
       "Q 39.75 74.21875 46.484375 68.546875 \n",
       "Q 53.21875 62.890625 53.21875 53.421875 \n",
       "Q 53.21875 48.921875 51.53125 44.890625 \n",
       "Q 49.859375 40.875 45.40625 35.40625 \n",
       "Q 44.1875 33.984375 37.640625 27.21875 \n",
       "Q 31.109375 20.453125 19.1875 8.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-32\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(228.614915 161.198437)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_6\">\n",
       "     <!-- epoch -->\n",
       "     <defs>\n",
       "      <path d=\"M 56.203125 29.59375 \n",
       "L 56.203125 25.203125 \n",
       "L 14.890625 25.203125 \n",
       "Q 15.484375 15.921875 20.484375 11.0625 \n",
       "Q 25.484375 6.203125 34.421875 6.203125 \n",
       "Q 39.59375 6.203125 44.453125 7.46875 \n",
       "Q 49.3125 8.734375 54.109375 11.28125 \n",
       "L 54.109375 2.78125 \n",
       "Q 49.265625 0.734375 44.1875 -0.34375 \n",
       "Q 39.109375 -1.421875 33.890625 -1.421875 \n",
       "Q 20.796875 -1.421875 13.15625 6.1875 \n",
       "Q 5.515625 13.8125 5.515625 26.8125 \n",
       "Q 5.515625 40.234375 12.765625 48.109375 \n",
       "Q 20.015625 56 32.328125 56 \n",
       "Q 43.359375 56 49.78125 48.890625 \n",
       "Q 56.203125 41.796875 56.203125 29.59375 \n",
       "z\n",
       "M 47.21875 32.234375 \n",
       "Q 47.125 39.59375 43.09375 43.984375 \n",
       "Q 39.0625 48.390625 32.421875 48.390625 \n",
       "Q 24.90625 48.390625 20.390625 44.140625 \n",
       "Q 15.875 39.890625 15.1875 32.171875 \n",
       "z\n",
       "\" id=\"DejaVuSans-65\"/>\n",
       "      <path d=\"M 18.109375 8.203125 \n",
       "L 18.109375 -20.796875 \n",
       "L 9.078125 -20.796875 \n",
       "L 9.078125 54.6875 \n",
       "L 18.109375 54.6875 \n",
       "L 18.109375 46.390625 \n",
       "Q 20.953125 51.265625 25.265625 53.625 \n",
       "Q 29.59375 56 35.59375 56 \n",
       "Q 45.5625 56 51.78125 48.09375 \n",
       "Q 58.015625 40.1875 58.015625 27.296875 \n",
       "Q 58.015625 14.40625 51.78125 6.484375 \n",
       "Q 45.5625 -1.421875 35.59375 -1.421875 \n",
       "Q 29.59375 -1.421875 25.265625 0.953125 \n",
       "Q 20.953125 3.328125 18.109375 8.203125 \n",
       "z\n",
       "M 48.6875 27.296875 \n",
       "Q 48.6875 37.203125 44.609375 42.84375 \n",
       "Q 40.53125 48.484375 33.40625 48.484375 \n",
       "Q 26.265625 48.484375 22.1875 42.84375 \n",
       "Q 18.109375 37.203125 18.109375 27.296875 \n",
       "Q 18.109375 17.390625 22.1875 11.75 \n",
       "Q 26.265625 6.109375 33.40625 6.109375 \n",
       "Q 40.53125 6.109375 44.609375 11.75 \n",
       "Q 48.6875 17.390625 48.6875 27.296875 \n",
       "z\n",
       "\" id=\"DejaVuSans-70\"/>\n",
       "      <path d=\"M 30.609375 48.390625 \n",
       "Q 23.390625 48.390625 19.1875 42.75 \n",
       "Q 14.984375 37.109375 14.984375 27.296875 \n",
       "Q 14.984375 17.484375 19.15625 11.84375 \n",
       "Q 23.34375 6.203125 30.609375 6.203125 \n",
       "Q 37.796875 6.203125 41.984375 11.859375 \n",
       "Q 46.1875 17.53125 46.1875 27.296875 \n",
       "Q 46.1875 37.015625 41.984375 42.703125 \n",
       "Q 37.796875 48.390625 30.609375 48.390625 \n",
       "z\n",
       "M 30.609375 56 \n",
       "Q 42.328125 56 49.015625 48.375 \n",
       "Q 55.71875 40.765625 55.71875 27.296875 \n",
       "Q 55.71875 13.875 49.015625 6.21875 \n",
       "Q 42.328125 -1.421875 30.609375 -1.421875 \n",
       "Q 18.84375 -1.421875 12.171875 6.21875 \n",
       "Q 5.515625 13.875 5.515625 27.296875 \n",
       "Q 5.515625 40.765625 12.171875 48.375 \n",
       "Q 18.84375 56 30.609375 56 \n",
       "z\n",
       "\" id=\"DejaVuSans-6f\"/>\n",
       "      <path d=\"M 48.78125 52.59375 \n",
       "L 48.78125 44.1875 \n",
       "Q 44.96875 46.296875 41.140625 47.34375 \n",
       "Q 37.3125 48.390625 33.40625 48.390625 \n",
       "Q 24.65625 48.390625 19.8125 42.84375 \n",
       "Q 14.984375 37.3125 14.984375 27.296875 \n",
       "Q 14.984375 17.28125 19.8125 11.734375 \n",
       "Q 24.65625 6.203125 33.40625 6.203125 \n",
       "Q 37.3125 6.203125 41.140625 7.25 \n",
       "Q 44.96875 8.296875 48.78125 10.40625 \n",
       "L 48.78125 2.09375 \n",
       "Q 45.015625 0.34375 40.984375 -0.53125 \n",
       "Q 36.96875 -1.421875 32.421875 -1.421875 \n",
       "Q 20.0625 -1.421875 12.78125 6.34375 \n",
       "Q 5.515625 14.109375 5.515625 27.296875 \n",
       "Q 5.515625 40.671875 12.859375 48.328125 \n",
       "Q 20.21875 56 33.015625 56 \n",
       "Q 37.15625 56 41.109375 55.140625 \n",
       "Q 45.0625 54.296875 48.78125 52.59375 \n",
       "z\n",
       "\" id=\"DejaVuSans-63\"/>\n",
       "      <path d=\"M 54.890625 33.015625 \n",
       "L 54.890625 0 \n",
       "L 45.90625 0 \n",
       "L 45.90625 32.71875 \n",
       "Q 45.90625 40.484375 42.875 44.328125 \n",
       "Q 39.84375 48.1875 33.796875 48.1875 \n",
       "Q 26.515625 48.1875 22.3125 43.546875 \n",
       "Q 18.109375 38.921875 18.109375 30.90625 \n",
       "L 18.109375 0 \n",
       "L 9.078125 0 \n",
       "L 9.078125 75.984375 \n",
       "L 18.109375 75.984375 \n",
       "L 18.109375 46.1875 \n",
       "Q 21.34375 51.125 25.703125 53.5625 \n",
       "Q 30.078125 56 35.796875 56 \n",
       "Q 45.21875 56 50.046875 50.171875 \n",
       "Q 54.890625 44.34375 54.890625 33.015625 \n",
       "z\n",
       "\" id=\"DejaVuSans-68\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(132.565625 174.876562)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-65\"/>\n",
       "      <use x=\"61.523438\" xlink:href=\"#DejaVuSans-70\"/>\n",
       "      <use x=\"125\" xlink:href=\"#DejaVuSans-6f\"/>\n",
       "      <use x=\"186.181641\" xlink:href=\"#DejaVuSans-63\"/>\n",
       "      <use x=\"241.162109\" xlink:href=\"#DejaVuSans-68\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"matplotlib.axis_2\">\n",
       "    <g id=\"ytick_1\">\n",
       "     <g id=\"line2d_6\">\n",
       "      <defs>\n",
       "       <path d=\"M 0 0 \n",
       "L -3.5 0 \n",
       "\" id=\"m095597e60d\" style=\"stroke:#000000;stroke-width:0.8;\"/>\n",
       "      </defs>\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m095597e60d\" y=\"136.597887\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_7\">\n",
       "      <!-- 0.25 -->\n",
       "      <g transform=\"translate(20.878125 140.397106)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-32\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_2\">\n",
       "     <g id=\"line2d_7\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m095597e60d\" y=\"111.158976\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_8\">\n",
       "      <!-- 0.30 -->\n",
       "      <defs>\n",
       "       <path d=\"M 40.578125 39.3125 \n",
       "Q 47.65625 37.796875 51.625 33 \n",
       "Q 55.609375 28.21875 55.609375 21.1875 \n",
       "Q 55.609375 10.40625 48.1875 4.484375 \n",
       "Q 40.765625 -1.421875 27.09375 -1.421875 \n",
       "Q 22.515625 -1.421875 17.65625 -0.515625 \n",
       "Q 12.796875 0.390625 7.625 2.203125 \n",
       "L 7.625 11.71875 \n",
       "Q 11.71875 9.328125 16.59375 8.109375 \n",
       "Q 21.484375 6.890625 26.8125 6.890625 \n",
       "Q 36.078125 6.890625 40.9375 10.546875 \n",
       "Q 45.796875 14.203125 45.796875 21.1875 \n",
       "Q 45.796875 27.640625 41.28125 31.265625 \n",
       "Q 36.765625 34.90625 28.71875 34.90625 \n",
       "L 20.21875 34.90625 \n",
       "L 20.21875 43.015625 \n",
       "L 29.109375 43.015625 \n",
       "Q 36.375 43.015625 40.234375 45.921875 \n",
       "Q 44.09375 48.828125 44.09375 54.296875 \n",
       "Q 44.09375 59.90625 40.109375 62.90625 \n",
       "Q 36.140625 65.921875 28.71875 65.921875 \n",
       "Q 24.65625 65.921875 20.015625 65.03125 \n",
       "Q 15.375 64.15625 9.8125 62.3125 \n",
       "L 9.8125 71.09375 \n",
       "Q 15.4375 72.65625 20.34375 73.4375 \n",
       "Q 25.25 74.21875 29.59375 74.21875 \n",
       "Q 40.828125 74.21875 47.359375 69.109375 \n",
       "Q 53.90625 64.015625 53.90625 55.328125 \n",
       "Q 53.90625 49.265625 50.4375 45.09375 \n",
       "Q 46.96875 40.921875 40.578125 39.3125 \n",
       "z\n",
       "\" id=\"DejaVuSans-33\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 114.958195)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_3\">\n",
       "     <g id=\"line2d_8\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m095597e60d\" y=\"85.720064\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_9\">\n",
       "      <!-- 0.35 -->\n",
       "      <g transform=\"translate(20.878125 89.519283)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-33\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_4\">\n",
       "     <g id=\"line2d_9\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m095597e60d\" y=\"60.281153\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_10\">\n",
       "      <!-- 0.40 -->\n",
       "      <defs>\n",
       "       <path d=\"M 37.796875 64.3125 \n",
       "L 12.890625 25.390625 \n",
       "L 37.796875 25.390625 \n",
       "z\n",
       "M 35.203125 72.90625 \n",
       "L 47.609375 72.90625 \n",
       "L 47.609375 25.390625 \n",
       "L 58.015625 25.390625 \n",
       "L 58.015625 17.1875 \n",
       "L 47.609375 17.1875 \n",
       "L 47.609375 0 \n",
       "L 37.796875 0 \n",
       "L 37.796875 17.1875 \n",
       "L 4.890625 17.1875 \n",
       "L 4.890625 26.703125 \n",
       "z\n",
       "\" id=\"DejaVuSans-34\"/>\n",
       "      </defs>\n",
       "      <g transform=\"translate(20.878125 64.080372)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-30\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"ytick_5\">\n",
       "     <g id=\"line2d_10\">\n",
       "      <g>\n",
       "       <use style=\"stroke:#000000;stroke-width:0.8;\" x=\"50.14375\" xlink:href=\"#m095597e60d\" y=\"34.842241\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "     <g id=\"text_11\">\n",
       "      <!-- 0.45 -->\n",
       "      <g transform=\"translate(20.878125 38.64146)scale(0.1 -0.1)\">\n",
       "       <use xlink:href=\"#DejaVuSans-30\"/>\n",
       "       <use x=\"63.623047\" xlink:href=\"#DejaVuSans-2e\"/>\n",
       "       <use x=\"95.410156\" xlink:href=\"#DejaVuSans-34\"/>\n",
       "       <use x=\"159.033203\" xlink:href=\"#DejaVuSans-35\"/>\n",
       "      </g>\n",
       "     </g>\n",
       "    </g>\n",
       "    <g id=\"text_12\">\n",
       "     <!-- loss -->\n",
       "     <defs>\n",
       "      <path d=\"M 9.421875 75.984375 \n",
       "L 18.40625 75.984375 \n",
       "L 18.40625 0 \n",
       "L 9.421875 0 \n",
       "z\n",
       "\" id=\"DejaVuSans-6c\"/>\n",
       "      <path d=\"M 44.28125 53.078125 \n",
       "L 44.28125 44.578125 \n",
       "Q 40.484375 46.53125 36.375 47.5 \n",
       "Q 32.28125 48.484375 27.875 48.484375 \n",
       "Q 21.1875 48.484375 17.84375 46.4375 \n",
       "Q 14.5 44.390625 14.5 40.28125 \n",
       "Q 14.5 37.15625 16.890625 35.375 \n",
       "Q 19.28125 33.59375 26.515625 31.984375 \n",
       "L 29.59375 31.296875 \n",
       "Q 39.15625 29.25 43.1875 25.515625 \n",
       "Q 47.21875 21.78125 47.21875 15.09375 \n",
       "Q 47.21875 7.46875 41.1875 3.015625 \n",
       "Q 35.15625 -1.421875 24.609375 -1.421875 \n",
       "Q 20.21875 -1.421875 15.453125 -0.5625 \n",
       "Q 10.6875 0.296875 5.421875 2 \n",
       "L 5.421875 11.28125 \n",
       "Q 10.40625 8.6875 15.234375 7.390625 \n",
       "Q 20.0625 6.109375 24.8125 6.109375 \n",
       "Q 31.15625 6.109375 34.5625 8.28125 \n",
       "Q 37.984375 10.453125 37.984375 14.40625 \n",
       "Q 37.984375 18.0625 35.515625 20.015625 \n",
       "Q 33.0625 21.96875 24.703125 23.78125 \n",
       "L 21.578125 24.515625 \n",
       "Q 13.234375 26.265625 9.515625 29.90625 \n",
       "Q 5.8125 33.546875 5.8125 39.890625 \n",
       "Q 5.8125 47.609375 11.28125 51.796875 \n",
       "Q 16.75 56 26.8125 56 \n",
       "Q 31.78125 56 36.171875 55.265625 \n",
       "Q 40.578125 54.546875 44.28125 53.078125 \n",
       "z\n",
       "\" id=\"DejaVuSans-73\"/>\n",
       "     </defs>\n",
       "     <g transform=\"translate(14.798437 88.307812)rotate(-90)scale(0.1 -0.1)\">\n",
       "      <use xlink:href=\"#DejaVuSans-6c\"/>\n",
       "      <use x=\"27.783203\" xlink:href=\"#DejaVuSans-6f\"/>\n",
       "      <use x=\"88.964844\" xlink:href=\"#DejaVuSans-73\"/>\n",
       "      <use x=\"141.064453\" xlink:href=\"#DejaVuSans-73\"/>\n",
       "     </g>\n",
       "    </g>\n",
       "   </g>\n",
       "   <g id=\"line2d_11\">\n",
       "    <path clip-path=\"url(#p45befb89b8)\" d=\"M 59.021023 16.877273 \n",
       "L 64.939205 67.069417 \n",
       "L 70.857386 90.448017 \n",
       "L 76.775568 103.805434 \n",
       "L 82.69375 115.996668 \n",
       "L 88.611932 123.471005 \n",
       "L 94.530114 130.217287 \n",
       "L 100.448295 134.061292 \n",
       "L 106.366477 135.16073 \n",
       "L 112.284659 135.931575 \n",
       "L 118.202841 136.498723 \n",
       "L 124.121023 135.432613 \n",
       "L 130.039205 137.364002 \n",
       "L 135.957386 138.008526 \n",
       "L 141.875568 139.179425 \n",
       "L 147.79375 138.838149 \n",
       "L 153.711932 138.732033 \n",
       "L 159.630114 138.869786 \n",
       "L 165.548295 140.102785 \n",
       "L 171.466477 139.894434 \n",
       "L 177.384659 138.77743 \n",
       "L 183.302841 139.734171 \n",
       "L 189.221023 140.422727 \n",
       "L 195.139205 139.827187 \n",
       "L 201.057386 138.606182 \n",
       "L 206.975568 138.3268 \n",
       "L 212.89375 135.784359 \n",
       "L 218.811932 139.8657 \n",
       "L 224.730114 138.046251 \n",
       "L 230.648295 138.137 \n",
       "L 236.566477 137.778734 \n",
       "\" style=\"fill:none;stroke:#1f77b4;stroke-linecap:square;stroke-width:1.5;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_3\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 50.14375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_4\">\n",
       "    <path d=\"M 245.44375 146.6 \n",
       "L 245.44375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_5\">\n",
       "    <path d=\"M 50.14375 146.6 \n",
       "L 245.44375 146.6 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "   <g id=\"patch_6\">\n",
       "    <path d=\"M 50.14375 10.7 \n",
       "L 245.44375 10.7 \n",
       "\" style=\"fill:none;stroke:#000000;stroke-linecap:square;stroke-linejoin:miter;stroke-width:0.8;\"/>\n",
       "   </g>\n",
       "  </g>\n",
       " </g>\n",
       " <defs>\n",
       "  <clipPath id=\"p45befb89b8\">\n",
       "   <rect height=\"135.9\" width=\"195.3\" x=\"50.14375\" y=\"10.7\"/>\n",
       "  </clipPath>\n",
       " </defs>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "d2l.train_gluon_ch7('rmsprop', {'learning_rate': 0.01, 'gamma1': 0.9},\n",
    "                    features, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* RMSProp算法和AdaGrad算法的不同在于，RMSProp算法使用了小批量随机梯度按元素平方的指数加权移动平均来调整学习率。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 把$\\gamma$的值设为1，实验结果有什么变化？为什么？\n",
    "* 试着使用其他的初始学习率和$\\gamma$超参数的组合，观察并分析实验结果。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## 参考文献\n",
    "\n",
    "[1] Tieleman, T., & Hinton, G. (2012). Lecture 6.5-rmsprop: Divide the gradient by a running average of its recent magnitude. COURSERA: Neural networks for machine learning, 4(2), 26-31.\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/2275)\n",
    "\n",
    "![](../img/qr_rmsprop.svg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}